In [1]:
#!/usr/bin/env python
# coding: utf-8

#LIBRARIES:
#Tensors:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils

#E(2)-steerable CNNs - librar"y:
from e2cnn import gspaces                                          
from e2cnn import nn as G_CNN   

#Plotting in 2d/3d:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import Ellipse
from matplotlib.colors import Normalize
import matplotlib.cm as cm

#Tools:
import datetime
import sys

#Own files:
import Kernel_and_GP_tools as GP
import My_Tools
import Steerable_CNP_Models as My_Models

#HYPERPARAMETERS:
#Set default as double:
torch.set_default_dtype(torch.float)
#Scale for plotting with plt quiver
quiver_scale=15
if torch.cuda.is_available():
    device = torch.device("cuda:0")  
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")


def SETUP_EXP_5_Cyclic_GP_div_free(Training_par,N=8,batch_size=3):
    
    G_act = gspaces.Rot2dOnR2(N=N)
    feat_type_in=G_CNN.FieldType(G_act, [G_act.irrep(1)])
    GP_train_data_loader,GP_test_data_loader=GP.load_2d_GP_data(Id="37845",batch_size=batch_size)
    GP_parameters={'l_scale':1,'sigma_var':1, 'kernel_type':"div_free",'obs_noise':1e-4,'B':None,'Ker_project':False}
    Operator_par={'train_data_loader': GP_train_data_loader,'test_data_loader': GP_test_data_loader}
    
    #Define the grid, the kernel parameters and the encoder:
    grid_dict={'x_range':[-3,3],'y_range':[-3,3],'n_x_axis':30,'n_y_axis':30}
    kernel_dict_emb={'sigma_var':1,'kernel_type':"rbf",'Ker_project':False}
    encoder=My_Models.Steerable_Encoder(**grid_dict,kernel_dict=kernel_dict_emb,normalize=True)
    
    #Define the kernel parameters for the kernel smoother:
    kernel_dict_out={'sigma_var':1,'kernel_type':"rbf",'B':None,'Ker_project':False}

    #---------------------Conv CNP decoder-----------------------------
    conv_decoder=nn.Sequential(nn.Conv2d(3,6,kernel_size=5,stride=1,padding=2),
              nn.ReLU(),
              nn.Conv2d(6,6,kernel_size=7,stride=1,padding=3),
              nn.ReLU(),
              nn.Conv2d(6,5,kernel_size=5,stride=1,padding=2))

    #---------------------Steerable CNP decoder-----------------------------
    #Define the f||eature types:
    psd_rep,_=My_Tools.get_pre_psd_rep(G_act)
    feat_type_out=G_CNN.FieldType(G_act,[G_act.irrep(1),psd_rep])
    feat_types=[G_CNN.FieldType(G_act, [G_act.trivial_repr,G_act.irrep(1)]),
                G_CNN.FieldType(G_act, 2*[G_act.regular_repr]),
                G_CNN.FieldType(G_act,2*[G_act.regular_repr]),
                feat_type_out]

    #Define the kernel sizes:
    kernel_sizes=[5,7,5]
    geom_decoder=My_Models.Steerable_Decoder(feat_types,kernel_sizes)
    
    #Get the convcnp:
    conv_cnp=My_Models.Steerable_CNP(feature_in=feat_type_in,dim_cov_est=3,G_act=G_act,encoder=encoder,decoder=conv_decoder,kernel_dict_out=kernel_dict_out)    
    geom_cnp=My_Models.Steerable_CNP(feature_in=feat_type_in,dim_cov_est=3,G_act=G_act,encoder=encoder,decoder=geom_decoder,kernel_dict_out=kernel_dict_out)
    
    #Send the models to the correct devices:
    conv_cnp=conv_cnp.to(device)
    geom_cnp=geom_cnp.to(device)
    
    #Create Operator model:
    Conv_CNP_Operator=My_Models.Steerable_CNP_Operator(conv_cnp,**Training_par,**Operator_par)
    Geom_CNP_Operator=My_Models.Steerable_CNP_Operator(geom_cnp,**Training_par,**Operator_par)
    
    return(Conv_CNP_Operator,Geom_CNP_Operator,GP_parameters)    
 
#Duration of training throughout all experiments:
n_epochs=30
n_iterat=500
train=False
evaluate=True
n_tests=400
#------------------------------------
#-----Experiment 5.1:
#----------------------------------------  
Training_par={'Max_n_context_points':50,'n_epochs':n_epochs,'n_plots':None,'n_iterat_per_epoch':n_iterat,
            'learning_rate':1e-3}    
Conv_CNP,Geom_CNP,GP_parameters=SETUP_EXP_5_Cyclic_GP_div_free(Training_par,N=4,batch_size=1)
if train:
    filename_51="Exp_5_1"
    starttime=datetime.datetime.today()
    print("Start training experiment 5.1: ", starttime)
    loss_Geom_CNP=Geom_CNP.train(filename="Initial_ziz_exp_1507/"+filename_51+"_Steerable_CNP_")
    loss_ConvCNP=Conv_CNP.train(filename="Initial_ziz_exp_1507/"+filename_51+"_Conv_CNP_")
    endtime=datetime.datetime.today()
    print("Duration of training on device: ",device,": ",endtime-starttime)
if evaluate:
    Conv_CNP.load_state_dict(torch.load("Trained_Models/Initial_ziz_exp_1507/Exp_5/Exp_5_1_Conv_CNP__2020_07_16_22_08",map_location=torch.device('cpu')))
    Geom_CNP.load_state_dict(torch.load("Trained_Models/Initial_ziz_exp_1507/Exp_5/Exp_5_1_Steerable_CNP__2020_07_16_22_02",map_location=torch.device('cpu')))
    X,Y=next(iter(Conv_CNP.test_data_loader))
    n_context_points=torch.randint(size=[],low=2,high=Conv_CNP.Max_n_context_points)
    x_context,y_context,x_target,y_target=My_Tools.Rand_Target_Context_Splitter(X[0],Y[0],n_context_points)
    Conv_CNP.plot_test(x_context,y_context,x_target,y_target,GP_parameters=None,title="Exp 5.1: ConvCNP")
    Geom_CNP.plot_test(x_context,y_context,x_target,y_target,GP_parameters=GP_parameters,title="Exp 5.1: SteerCNP")
    print("Exp. 5.1: Log-LL Steer.: ",Geom_CNP.test(n_tests))
    print("Exp. 5.1: Log-LL Conv.: ",Conv_CNP.test(n_tests))
    Conv_CNP.test_equivariance_model(plot=True,title="ConvCNP")
    Geom_CNP.test_equivariance_model(plot=True,title="SteerCNP")
#------------------------------------
#-----Experiment 5.2:
#----------------------------------------  
Training_par={'Max_n_context_points':50,'n_epochs':n_epochs,'n_plots':None,'n_iterat_per_epoch':n_iterat//3,
            'learning_rate':1e-4}    
Conv_CNP,Geom_CNP,GP_parameters=SETUP_EXP_5_Cyclic_GP_div_free(Training_par,N=4,batch_size=3)
if train:
    filename_52="Exp_5_2"
    starttime=datetime.datetime.today()
    print("Start training experiment 5.2: ", starttime)
    loss_Geom_CNP=Geom_CNP.train(filename="Initial_ziz_exp_1507/"+filename_52+"_Steerable_CNP_")
    loss_ConvCNP=Conv_CNP.train(filename="Initial_ziz_exp_1507/"+filename_52+"_Conv_CNP_")
    endtime=datetime.datetime.today()
    print("Duration of training on device: ",device,": ",endtime-starttime)
if evaluate:
    Conv_CNP.load_state_dict(torch.load("Trained_Models/Initial_ziz_exp_1507/Exp_5/Exp_5_2_Conv_CNP__2020_07_16_22_19",map_location=torch.device('cpu')))
    Geom_CNP.load_state_dict(torch.load("Trained_Models/Initial_ziz_exp_1507/Exp_5/Exp_5_2_Steerable_CNP__2020_07_16_22_14",map_location=torch.device('cpu')))
    X,Y=next(iter(Conv_CNP.test_data_loader))
    n_context_points=torch.randint(size=[],low=2,high=Conv_CNP.Max_n_context_points)
    x_context,y_context,x_target,y_target=My_Tools.Rand_Target_Context_Splitter(X[0],Y[0],n_context_points)
    Conv_CNP.plot_test(x_context,y_context,x_target,y_target,GP_parameters=None,title="Exp 5.1: ConvCNP")
    Geom_CNP.plot_test(x_context,y_context,x_target,y_target,GP_parameters=GP_parameters,title="Exp 5.1: SteerCNP")
    print("Exp. 5.1: Log-LL Steer.: ",Geom_CNP.test(n_tests))
    print("Exp. 5.1: Log-LL Conv.: ",Conv_CNP.test(n_tests))
    Conv_CNP.test_equivariance_model(plot=True,title="ConvCNP")
    Geom_CNP.test_equivariance_model(plot=True,title="SteerCNP")
#------------------------------------
#-----Experiment 2.3:
#----------------------------------------  
Training_par={'Max_n_context_points':50,'n_epochs':n_epochs,'n_plots':None,'n_iterat_per_epoch':n_iterat//5,
            'learning_rate':1e-2}    
Conv_CNP,Geom_CNP,GP_parameters=SETUP_EXP_5_Cyclic_GP_div_free(Training_par,N=4,batch_size=5)
if train:
    filename_53="Exp_5_3"
    starttime=datetime.datetime.today()
    print("Start training experiment 5.3: ", starttime)
    loss_Geom_CNP=Geom_CNP.train(filename="Initial_ziz_exp_1507/"+filename_53+"_Steerable_CNP_")
    loss_ConvCNP=Conv_CNP.train(filename="Initial_ziz_exp_1507/"+filename_53+"_Conv_CNP_")
    endtime=datetime.datetime.today()
    print("Duration of training on device: ",device,": ",endtime-starttime)
if evaluate:
    Conv_CNP.load_state_dict(torch.load("Trained_Models/Initial_ziz_exp_1507/Exp_5/Exp_5_3_Conv_CNP__2020_07_16_22_30",map_location=torch.device('cpu')))
    Geom_CNP.load_state_dict(torch.load("Trained_Models/Initial_ziz_exp_1507/Exp_5/Exp_5_3_Steerable_CNP__2020_07_16_22_25",map_location=torch.device('cpu')))
    X,Y=next(iter(Conv_CNP.test_data_loader))
    n_context_points=torch.randint(size=[],low=2,high=Conv_CNP.Max_n_context_points)
    x_context,y_context,x_target,y_target=My_Tools.Rand_Target_Context_Splitter(X[0],Y[0],n_context_points)
    Conv_CNP.plot_test(x_context,y_context,x_target,y_target,GP_parameters=None,title="Exp 5.1: ConvCNP")
    Geom_CNP.plot_test(x_context,y_context,x_target,y_target,GP_parameters=GP_parameters,title="Exp 5.1: SteerCNP")
    print("Exp. 5.1: Log-LL Steer.: ",Geom_CNP.test(n_tests))
    print("Exp. 5.1: Log-LL Conv.: ",Conv_CNP.test(n_tests))
    Conv_CNP.test_equivariance_model(plot=True,title="ConvCNP")
    Geom_CNP.test_equivariance_model(plot=True,title="SteerCNP")
Running on the CPU
Exp. 5.1: Log-LL Steer.:  -1.9176439046859741
Exp. 5.1: Log-LL Conv.:  -2.069878578186035
Exp. 5.1: Log-LL Steer.:  -2.4562535285949707
Exp. 5.1: Log-LL Conv.:  -2.337111473083496
Exp. 5.1: Log-LL Steer.:  -2.2246949672698975
Exp. 5.1: Log-LL Conv.:  -2.604445457458496
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>